各种Normalization

Batch/Layer/Instance/Weight Normalization

从使用最广泛的batch norm说起。

Batch Normalization

batch norm是为了解决网络中各个层的输入数据的分布不稳定的问题( (internal) covariate shift) ),因为若此的话各个层要不停地去适应新的数据分布。

设网络中的某一层的输入\(x=(x^{(1)},…,x^{(k)})\)是d维的,那么按如下的方式normalize它: \[ \hat x^{(k)} = \frac{x^{(k)}-\mathrm{E}[x^{(k)}]}{\sqrt{\mathrm{Var}[x^{(k)}]}} \]

为了让网络有恢复数据scale的能力,即能学习到恒等映射,还要对\(\hat x^{(k)}\)再做一次变换: \[ y^{(k)}=\gamma^{(k)}\hat x^{(k)}+\beta^{(k)} \] 其中\(\gamma^{(k)}\)\(\beta^{(k)}\)是待学习参数。

在实际的训练中,我们不太可能拿整个数据集来计算期望E和方差Var,所以实际上E和Var是在每个batch上做。

在inference阶段,

we want the output to depend only on the input, deterministically.

所以此时的E和Var是针对整个训练集得到的结果。在实际使用中常采用移动平均的方法计算population mean 和 population variance。

Weight normalization

网络中的一个神经元做的事就是对接收到的按权重求和的输入特征做非线性变换 \[ y = \phi(\mathbf{w}\cdot\mathbf{x}+b), \] (上一节的\(x\) 是这一节的\(\mathbf{w}\cdot\mathbf{x}+b\))。论文的作者提出\(\mathbf{w}\)可以重参数化为如下形式 \[ \mathbf{w} = \frac{g}{||\mathbf{v}||}\mathbf{v} \] 其中\(g\)是一个待学习标量,\(\mathbf{v}\)是与\(\mathbf{w}\) 同维度的矢量。注意到\(\mathbf{w}\) 有确定的欧式范数\(g\)

这样参数化之后,损失\(L\)\(\mathbf{w}\)的梯度可写成: \[ \nabla_g L = \frac{\nabla_{\mathrm{w}}L\cdot\mathbf{v}}{||\mathbf{v}||},\quad\nabla_{\mathbf{v}}L = \frac{g}{||\mathbf{v}||}\nabla_{\mathbf{w}}L-\frac{g\nabla_g L}{||\mathbf{v}||^2}\mathbf{v} \label{gradient1} \]

其中\(\nabla_\mathbf{w}L​\)是通常计算方式下的\(L​\)\(\mathbf{w}​\)的梯度。

上式中的第二式还可写成 \[ \nabla_{\mathbf{v}}L = \frac{g}{||\mathbf{v}||}M_{\mathbf{w}}\nabla_{\mathbf{w}}L,\quad \mathrm{with \quad M_{\mathbf{w}}}=I-\frac{\mathbf{w}\mathbf{w}^T}{||\mathbf{w}||^2} \] \(M_{\mathbf{w}}\)可看成投影到\(\mathbf{w}\)正交方向上的投影矩阵。

作者的一通分析后指出,采用weight normalization可以使得优化过程对learning rate不那么挑剔。

初始化问题

与batch normalization不同,weight normalization需要合适地选取初始化方案。

一种方法是,先让网络过一个mini-batch,

  1. 在每个neuron的输入\(\mathbf{x}\)上计算\(t=\frac{\mathbf{v}\cdot\mathbf{x}}{||\mathbf{v}||}\)

  2. neuron的输出\(y=\phi\left(\frac{t-\mu[t]}{\sigma[t]}\right)\) , \(\mu\)\(\sigma\) 是minibatch上的平均值和标准差。

  3. 初始化\(g\leftarrow\frac{1}{\sigma[t]},\quad b\leftarrow\frac{-\mu[t]}{\sigma[t]}\)

实际上在不采用weight normalization的网络上也可采用这样的初始化方案。然而无法应用到RNN。

Mean-only Batch Normalization

论文中用这种方法得到了CIFAR-10上不使用data augmentation下的state-of-the-art。

对每个神经元计算

  1. \(t=\mathbf{w}\cdot\mathbf{x}\), 其中\(\mathbf{w}\)仍按上文的方式参数化
  2. \(\tilde t=t-\mu[t]+b\)
  3. \(y = \phi(\tilde t)=\phi(t-\mu[t]+b)\)

损失\(L\)\(t\)的梯度为 \[ \nabla_t L = \nabla_{\tilde t} L -\mu[\nabla_{\tilde t}L] \] 与完整的batch normalization相比计算量较小。

引用一下原文

Mean-only batch normalization thus has the effect of centering the gradients that are backpropagated.

Layer normalization

在一个layer的所有input上做normalization,仅在样本内部做,和batch内的其他样本无关: \[ \mu = \frac{1}{k}\sum_{i=1}^{k}x^{(i)} \\ \sigma = \sqrt{\frac{1}{k}\sum_{i=1}^{k}(x^{(k)}-\mu)} \] 注意到这里对所有input的求和并不包括bias。Layer normalization 配合RNN使用最佳。

特点:

  1. 不依赖batch大小。
  2. 可直接用于RNN。
  3. 训练和测试的计算过程是一样的。

权重和数据变换下的不变性

三种normalization都可以总结为对neuron的summed inputs \(x\)\(\sigma\)\(\mu\)做normalization后喂给激活函数 \[ y=\phi(\frac{g}{\sigma}(x-\mu)+b) \] \(g\)\(b\)是学习参数。对于weight normalization来说 \(\mu=0, \sigma=||\mathbf{w}||\)

三种normalization在不同变换下的性质如下,摘自文献[3], W表示权重矩阵,w表示权重向量:

W re-scaling W re-centering w re-scaling dataset re-scaling dataset re-centering single sample re-scaling
BN Invariant No Invariant Invariant Invariant No
WN Invariant No Invariant No No No
LN Invariant Invariant No Invariant No Invariant

Instance normalization

Instance normalization主要用在风格转移中,用于提升生成图片质量。论文的作者建议,在生成网络中,用如下的 instance normalization 替换batch normalization: \[ y_{tijk}=\frac{x_{tijk}-\mu_{ti}}{\sqrt{\sigma^2_{ti}+\epsilon}},\ \mu_{ti}=\frac{1}{HW}\sum^W_{l=1}\sum^H_{m=1}x_{tilm},\ \sigma^2_{ti}=\frac{1}{HW}\sum^W_{l=1}\sum^H_{m=1}(x_{tilm}-\mu_{ti})^2 \] \(x\)的四个下标分别表示index in batch, channel, width and hight.

与batch normalization不同的是,instance normalization在训练时与在测试时的使用没有差别。

两点值得比较的地方:

  1. 若采用batch normalization , 那么\(\mu\)\(\sigma\)的计算式中会多出对\(t\)的求和,即,要跨样本算平均。
  2. 若采用layer normalization, 那么\(\mu\)\(\sigma\)的计算式中会多出对\(i\)的求和,即,要跨通道算平均。

参考文献

[1]. https://arxiv.org/abs/1502.03167 Batch Normalization

[2]. https://arxiv.org/abs/1602.07868 Weight Normalization

[3]. https://arxiv.org/abs/1607.06450 Layer Normalization

[4]. https://arxiv.org/abs/1607.08022 Instance Normalization